{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# How to force tool calling behavior\n",
        "\n",
        "```{=mdx}\n",
        "\n",
        ":::info Prerequisites\n",
        "\n",
        "This guide assumes familiarity with the following concepts:\n",
        "- [Chat models](/docs/concepts/chat_models)\n",
        "- [LangChain Tools](/docs/concepts/tools)\n",
        "- [How to use a model to call tools](/docs/how_to/tool_calling)\n",
        "\n",
        ":::\n",
        "\n",
        "```\n",
        "\n",
        "In order to force our LLM to select a specific tool, we can use the `tool_choice` parameter to ensure certain behavior. First, let's define our model and tools:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {},
      "outputs": [],
      "source": [
        "import { tool } from '@langchain/core/tools';\n",
        "import { z } from 'zod';\n",
        "\n",
        "const add = tool((input) => {\n",
        "    return `${input.a + input.b}`\n",
        "}, {\n",
        "    name: \"add\",\n",
        "    description: \"Adds a and b.\",\n",
        "    schema: z.object({\n",
        "        a: z.number(),\n",
        "        b: z.number(),\n",
        "    })\n",
        "})\n",
        "\n",
        "const multiply = tool((input) => {\n",
        "    return `${input.a * input.b}`\n",
        "}, {\n",
        "    name: \"Multiply\",\n",
        "    description: \"Multiplies a and b.\",\n",
        "    schema: z.object({\n",
        "        a: z.number(),\n",
        "        b: z.number(),\n",
        "    })\n",
        "})\n",
        "\n",
        "const tools = [add, multiply]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "import { ChatOpenAI } from '@langchain/openai';\n",
        "\n",
        "const llm = new ChatOpenAI({\n",
        "  model: \"gpt-3.5-turbo\",\n",
        "})"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For example, we can force our tool to call the multiply tool by using the following code:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[\n",
            "  {\n",
            "    \"name\": \"Multiply\",\n",
            "    \"args\": {\n",
            "      \"a\": 2,\n",
            "      \"b\": 4\n",
            "    },\n",
            "    \"type\": \"tool_call\",\n",
            "    \"id\": \"call_d5isFbUkn17Wjr6yEtNz7dDF\"\n",
            "  }\n",
            "]\n"
          ]
        }
      ],
      "source": [
        "const llmForcedToMultiply = llm.bindTools(tools, {\n",
        "  tool_choice: \"Multiply\",\n",
        "})\n",
        "const multiplyResult = await llmForcedToMultiply.invoke(\"what is 2 + 4\");\n",
        "console.log(JSON.stringify(multiplyResult.tool_calls, null, 2));"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Even if we pass it something that doesn't require multiplcation - it will still call the tool!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can also just force our tool to select at least one of our tools by passing `\"any\"` (or for OpenAI models, the equivalent, `\"required\"`) to the `tool_choice` parameter."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "[\n",
            "  {\n",
            "    \"name\": \"add\",\n",
            "    \"args\": {\n",
            "      \"a\": 2,\n",
            "      \"b\": 3\n",
            "    },\n",
            "    \"type\": \"tool_call\",\n",
            "    \"id\": \"call_La72g7Aj0XHG0pfPX6Dwg2vT\"\n",
            "  }\n",
            "]\n"
          ]
        }
      ],
      "source": [
        "const llmForcedToUseTool = llm.bindTools(tools, {\n",
        "  tool_choice: \"any\",\n",
        "})\n",
        "const anyToolResult = await llmForcedToUseTool.invoke(\"What day is today?\");\n",
        "console.log(JSON.stringify(anyToolResult.tool_calls, null, 2));"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "TypeScript",
      "language": "typescript",
      "name": "tslab"
    },
    "language_info": {
      "codemirror_mode": {
        "mode": "typescript",
        "name": "javascript",
        "typescript": true
      },
      "file_extension": ".ts",
      "mimetype": "text/typescript",
      "name": "typescript",
      "version": "3.7.2"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
